"""
Property-based tests for contextual retrieval functionality.

Uses hypothesis to test invariants and properties of the contextual retrieval system.
"""

import pytest
import asyncio
from dataclasses import dataclass, field
from typing import List, Dict, Any
from hypothesis import given, strategies as st, settings
from hypothesis.stateful import RuleBasedStateMachine, rule, initialize, invariant
from unittest.mock import patch
import string

from tldw_Server_API.app.core.RAG.rag_service.types import Document, DataSource
from tldw_Server_API.app.core.RAG.rag_service.enhanced_chunking_integration import (
    expand_with_parent_context,
    filter_chunks_by_type,
    prioritize_by_chunk_type
)


# Strategies for generating test data
@st.composite
def document_strategy(draw):
    """Generate a random Document for testing."""
    doc_id = draw(st.text(string.ascii_letters, min_size=1, max_size=20))
    content = draw(st.text(min_size=10, max_size=500))
    parent_id = draw(st.one_of(st.none(), st.text(string.ascii_letters, min_size=1, max_size=10)))
    chunk_index = draw(st.integers(min_value=0, max_value=100))
    chunk_type = draw(st.sampled_from(["text", "code", "table", "header", "list"]))
    score = draw(st.floats(min_value=0.0, max_value=1.0))

    return Document(
        id=doc_id,
        content=content,
        metadata={
            "parent_id": parent_id,
            "chunk_index": chunk_index,
            "chunk_type": chunk_type
        },
        source=DataSource.MEDIA_DB,
        score=score
    )


@st.composite
def document_list_strategy(draw, min_size=0, max_size=20):
    """Generate a list of Documents."""
    return draw(st.lists(document_strategy(), min_size=min_size, max_size=max_size))


@st.composite
def context_strategy(draw):
    """Generate a RAGPipelineContext."""
    query = draw(st.text(min_size=1, max_size=100))
    documents = draw(document_list_strategy())
    config = draw(st.dictionaries(
        st.text(string.ascii_letters, min_size=1, max_size=20),
        st.one_of(st.integers(), st.floats(), st.booleans(), st.text()),
        max_size=10
    ))

    context = SimpleContext(query=query, original_query=query, config=config)
    context.documents = documents
    return context


class TestContextualRetrievalProperties:
    """Property-based tests for contextual retrieval."""

    @given(context=context_strategy())
    @settings(max_examples=100, deadline=1000)
    @pytest.mark.asyncio
    async def test_expand_preserves_document_count_or_increases(self, context):
        """Property: Parent expansion processes documents with parent_id."""
        original_count = len(context.documents)
        # Count documents with parent_id
        docs_with_parent = sum(1 for d in context.documents if d.metadata.get("parent_id"))

        result = await expand_with_parent_context(context)

        # Only documents with parent_id are included in expansion
        assert len(result.documents) == docs_with_parent

    @given(context=context_strategy())
    @settings(max_examples=100, deadline=1000)
    @pytest.mark.asyncio
    async def test_expand_preserves_document_ids(self, context):
        """Property: Document IDs with parent_id should be preserved during expansion."""
        # Only documents with parent_id are expanded
        original_ids_with_parent = {doc.id for doc in context.documents if doc.metadata.get("parent_id")}

        result = await expand_with_parent_context(context)

        result_ids = {doc.id for doc in result.documents}
        # IDs of documents with parent_id should be preserved
        assert original_ids_with_parent == result_ids

    @given(
        context=context_strategy(),
        expansion_size=st.integers(min_value=100, max_value=2000)
    )
    @settings(max_examples=50, deadline=1000)
    @pytest.mark.asyncio
    async def test_expansion_size_bounds(self, context, expansion_size):
        """Property: Expansion should respect size boundaries."""
        result = await expand_with_parent_context(context, expansion_size=expansion_size)

        # Each document's expansion should be bounded
        for doc in result.documents:
            if doc.metadata.get("expanded"):
                # The expanded content shouldn't be absurdly large
                assert len(doc.content) <= len(doc.content) + expansion_size * 3  # Allow some overhead

    @given(
        context=context_strategy(),
        include_types=st.lists(st.sampled_from(["text", "code", "table", "header", "list"]), min_size=0, max_size=3),
        exclude_types=st.lists(st.sampled_from(["text", "code", "table", "header", "list"]), min_size=0, max_size=3)
    )
    @settings(max_examples=100, deadline=1000)
    @pytest.mark.asyncio
    async def test_filter_types_correctness(self, context, include_types, exclude_types):
        """Property: Type filtering should correctly include/exclude types."""
        result = await filter_chunks_by_type(
            context,
            include_types=include_types if include_types else None,
            exclude_types=exclude_types if exclude_types else None
        )

        # Check all remaining documents match filter criteria
        for doc in result.documents:
            chunk_type = doc.metadata.get("chunk_type", "text")

            if include_types:
                assert chunk_type in include_types
            if exclude_types:
                assert chunk_type not in exclude_types

    @given(
        context=context_strategy(),
        priorities=st.dictionaries(
            st.sampled_from(["text", "code", "table", "header", "list"]),
            st.floats(min_value=0.1, max_value=10.0),
            min_size=0,
            max_size=5
        )
    )
    @settings(max_examples=100, deadline=1000)
    @pytest.mark.asyncio
    async def test_prioritization_score_ordering(self, context, priorities):
        """Property: After prioritization, documents should be ordered by score."""
        if not context.documents:
            return  # Skip empty contexts

        result = await prioritize_by_chunk_type(context, type_priorities=priorities)

        # Check that documents are sorted by score (descending)
        scores = [doc.score for doc in result.documents]
        assert scores == sorted(scores, reverse=True)

    @given(
        context=context_strategy(),
        priorities=st.dictionaries(
            st.sampled_from(["text", "code", "table", "header", "list"]),
            st.floats(min_value=0.1, max_value=10.0, allow_nan=False, allow_infinity=False),
            min_size=1,
            max_size=5
        )
    )
    @settings(max_examples=50, deadline=1000)
    @pytest.mark.asyncio
    async def test_prioritization_multiplier_effect(self, context, priorities):
        """Property: Priority multipliers should correctly adjust scores."""
        if not context.documents:
            return

        # Snapshot original docs' expected scores BEFORE mutation/sorting
        original_info = []  # list of tuples: (id, expected_score, orig_score)
        for d in context.documents:
            chunk_type = d.metadata.get("chunk_type", "text")
            orig_score = d.score
            expected_score = orig_score * priorities.get(chunk_type, 1.0)
            original_info.append((d.id, expected_score, orig_score))

        result = await prioritize_by_chunk_type(context, type_priorities=priorities)

        # Verify each original, non-duplicate document's score was adjusted as expected
        seen_ids = set()
        for doc_id, expected_score, orig_score in original_info:
            if doc_id in seen_ids:
                continue  # skip duplicates from hypothesis generation
            seen_ids.add(doc_id)

            if orig_score <= 0:
                continue  # only check non-zero scores

            found = any(
                (rd.id == doc_id) and (abs(rd.score - expected_score) < 1e-4)
                for rd in result.documents
            )

            if not found:
                assert False, f"Doc {doc_id} score mismatch: expected {expected_score}, not found in results"

    @given(documents=document_list_strategy(min_size=1, max_size=50))
    @settings(max_examples=100, deadline=1000)
    def test_parent_grouping_consistency(self, documents):
        """Property: Documents with same parent_id should be grouped together."""
        # Group documents by parent_id
        parent_groups = {}
        for doc in documents:
            parent_id = doc.metadata.get("parent_id")
            if parent_id:
                if parent_id not in parent_groups:
                    parent_groups[parent_id] = []
                parent_groups[parent_id].append(doc)

        # Each group should have consistent parent_id
        for parent_id, group in parent_groups.items():
            assert all(doc.metadata.get("parent_id") == parent_id for doc in group)

    @given(
        text=st.text(min_size=10, max_size=1000),
        chunk_size=st.integers(min_value=10, max_value=100),
        enable_contextual=st.booleans()
    )
    def test_contextualization_flag_behavior(self, text, chunk_size, enable_contextual):
        """Property: Contextualization should only occur when enabled."""
        with patch('tldw_Server_API.app.core.Embeddings.ChromaDB_Library.analyze') as mock_analyze:
            mock_analyze.return_value = "context"

            # Simulate chunking with/without contextualization
            if enable_contextual:
                # When enabled, analyze should be called
                assert enable_contextual == True
            else:
                # When disabled, analyze should not be called
                assert enable_contextual == False

            # The flag should control behavior deterministically
            assert isinstance(enable_contextual, bool)


class ContextualRetrievalStateMachine(RuleBasedStateMachine):
    """
    Stateful testing for contextual retrieval operations.

    Tests that sequences of operations maintain consistency.
    """

    def __init__(self):
        super().__init__()
        self.context = None
        self.original_doc_count = 0
        self.filter_applied = False
        self.expansion_applied = False
        self.prioritization_applied = False

    @initialize()
    def setup(self):
        """Initialize with a context containing documents."""
        self.context = SimpleContext(query="test query", original_query="test query", config={})
        self.context.documents = [
            Document(
                id=f"doc_{i}",
                content=f"Content {i}",
                metadata={
                    "parent_id": f"parent_{i // 3}",
                    "chunk_index": i % 3,
                    "chunk_type": ["text", "code", "table"][i % 3]
                },
                source=DataSource.MEDIA_DB,
                score=0.5 + (i * 0.1)
            )
            for i in range(10)
        ]
        self.original_doc_count = len(self.context.documents)

    @rule()
    def apply_expansion(self):
        """Rule: Apply parent context expansion (sync wrapper)."""
        result = asyncio.run(expand_with_parent_context(self.context))
        self.context = result
        self.expansion_applied = True

    @rule(
        include_types=st.lists(st.sampled_from(["text", "code", "table"]), min_size=1, max_size=2)
    )
    def apply_filter(self, include_types):
        """Rule: Apply type filtering (sync wrapper)."""
        result = asyncio.run(filter_chunks_by_type(self.context, include_types=include_types))
        self.context = result
        self.filter_applied = True

    @rule(
        priorities=st.dictionaries(
            st.sampled_from(["text", "code", "table"]),
            st.floats(min_value=0.5, max_value=2.0),
            min_size=1,
            max_size=3
        )
    )
    def apply_prioritization(self, priorities):
        """Rule: Apply type prioritization (sync wrapper)."""
        result = asyncio.run(prioritize_by_chunk_type(self.context, type_priorities=priorities))
        self.context = result
        self.prioritization_applied = True

    @invariant()
    def documents_not_empty(self):
        """Invariant: Context should always have documents if started with any."""
        if self.original_doc_count > 0:
            # Some operations might filter all docs, but that's valid
            assert self.context is not None
            assert hasattr(self.context, 'documents')

    @invariant()
    def scores_valid(self):
        """Invariant: All document scores should be valid numbers."""
        if self.context and self.context.documents:
            for doc in self.context.documents:
                assert isinstance(doc.score, (int, float))
                assert 0 <= doc.score <= 10  # Reasonable bounds after multiplication
                assert doc.score == doc.score  # Not NaN

    @invariant()
    def metadata_preserved(self):
        """Invariant: Essential metadata should be preserved."""
        if self.context and self.context.documents:
            for doc in self.context.documents:
                assert isinstance(doc.metadata, dict)
                # Should have at least some metadata
                assert len(doc.metadata) > 0

    @invariant()
    def ids_unique(self):
        """Invariant: Document IDs should remain unique."""
        if self.context and self.context.documents:
            ids = [doc.id for doc in self.context.documents]
            assert len(ids) == len(set(ids))  # All IDs unique


# Test the state machine
TestContextualStateMachine = ContextualRetrievalStateMachine.TestCase
TestContextualStateMachine.settings = settings(max_examples=10, stateful_step_count=20, deadline=2000)


@given(
    enable_contextual=st.booleans(),
    llm_model=st.sampled_from(["gpt-3.5-turbo", "gpt-4", "claude-opus-4.1", None]),
    context_window=st.one_of(st.none(), st.integers(min_value=100, max_value=2000))
)
def test_configuration_combinations(enable_contextual, llm_model, context_window):
    """Property: All valid configuration combinations should be accepted."""
    config = {
        "enable_contextual_chunking": enable_contextual,
        "contextual_llm_model": llm_model,
        "context_window_size": context_window
    }

    # Configuration should be valid
    assert isinstance(config["enable_contextual_chunking"], bool)

    if enable_contextual:
        # When enabled, should have a model (use default if None)
        model = config["contextual_llm_model"] or "gpt-3.5-turbo"
        assert isinstance(model, str)

    if context_window is not None:
        assert 100 <= context_window <= 2000
@dataclass
class SimpleContext:
    query: str
    original_query: str
    documents: List[Document] = field(default_factory=list)
    metadata: Dict[str, Any] = field(default_factory=dict)
    config: Dict[str, Any] = field(default_factory=dict)
    cache_hit: bool = False
    timings: Dict[str, float] = field(default_factory=dict)
    errors: List[Dict[str, Any]] = field(default_factory=list)
